{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "38c55bde",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "26fdd9ed",
   "metadata": {},
   "source": [
    "# Federated MONAI MedNIST Example\n",
    "\n",
    "This demo shows federated learning training and validation for 2D medical image registration.\n",
    "\n",
    "Based on the MONAI [registration_mednist.ipynb](../../../../2d_registration/registration_mednist.ipynb) notebook and [OpenFL](https://github.com/intel/openfl), a federated learning framework."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77d85648",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cb238c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# install workspace requirements\n",
    "! pip install -r workspace_requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebff0b8e",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "billion-drunk",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import tqdm\n",
    "from torch.nn import MSELoss\n",
    "\n",
    "from monai.config import USE_COMPILED, print_config\n",
    "from monai.data import DataLoader, Dataset\n",
    "from monai.transforms import (\n",
    "    Compose,\n",
    "    EnsureChannelFirstD,\n",
    "    EnsureTypeD,\n",
    "    LoadImageD,\n",
    "    RandRotateD,\n",
    "    RandZoomD,\n",
    "    ScaleIntensityRanged,\n",
    ")\n",
    "from monai.networks.blocks import Warp\n",
    "from monai.networks.nets import GlobalNet\n",
    "\n",
    "from openfl.interface.interactive_api.experiment import (\n",
    "    DataInterface,\n",
    "    FLExperiment,\n",
    "    ModelInterface,\n",
    "    TaskInterface,\n",
    ")\n",
    "from openfl.interface.interactive_api.federation import Federation\n",
    "\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "246f9c98",
   "metadata": {},
   "source": [
    "## Connect to the Federation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d657e463",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a federation\n",
    "\n",
    "# Please use the same identifier that was used in the signed certificate\n",
    "client_id = \"api\"\n",
    "cert_dir = \"cert\"\n",
    "director_node_fqdn = \"localhost\"\n",
    "director_port = 50051\n",
    "# 1) Run with API layer - Director mTLS\n",
    "# If the user wants to enable mTLS, they must provide the CA root chain and a signed key pair to the federation interface\n",
    "# cert_chain = f'{cert_dir}/root_ca.crt'\n",
    "# api_certificate = f'{cert_dir}/{client_id}.crt'\n",
    "# api_private_key = f'{cert_dir}/{client_id}.key'\n",
    "\n",
    "# federation = Federation(client_id=client_id, director_node_fqdn=director_node_fqdn, director_port=director_port,\n",
    "#                        cert_chain=cert_chain, api_cert=api_certificate, api_private_key=api_private_key)\n",
    "\n",
    "# --------------------------------------------------------------------------------------------------------------------\n",
    "\n",
    "# 2) Run with TLS disabled (trusted environment)\n",
    "# The Federation can also determine the local FQDN automatically\n",
    "federation = Federation(\n",
    "    client_id=client_id,\n",
    "    director_node_fqdn=director_node_fqdn,\n",
    "    director_port=director_port,\n",
    "    tls=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1abebd90",
   "metadata": {},
   "outputs": [],
   "source": [
    "federation.target_shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47dcfab3",
   "metadata": {},
   "outputs": [],
   "source": [
    "shard_registry = federation.get_shard_registry()\n",
    "shard_registry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2a6c237",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First, request a dummy_shard_desc that holds information about the federated dataset\n",
    "dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)\n",
    "dummy_shard_dataset = dummy_shard_desc.get_dataset(\"train\")\n",
    "sample, target = dummy_shard_dataset[0]\n",
    "print(sample.shape)\n",
    "print(target.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc0dbdbd",
   "metadata": {},
   "source": [
    "## Creating a FL experiment using Interactive API"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0979470",
   "metadata": {},
   "source": [
    "### Register dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dda1680",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_transforms = Compose(\n",
    "    [\n",
    "        LoadImageD(keys=[\"fixed_hand\", \"moving_hand\"]),\n",
    "        EnsureChannelFirstD(keys=[\"fixed_hand\", \"moving_hand\"]),\n",
    "        ScaleIntensityRanged(\n",
    "            keys=[\"fixed_hand\", \"moving_hand\"],\n",
    "            a_min=0.0,\n",
    "            a_max=255.0,\n",
    "            b_min=0.0,\n",
    "            b_max=1.0,\n",
    "            clip=True,\n",
    "        ),\n",
    "        RandRotateD(\n",
    "            keys=[\"moving_hand\"],\n",
    "            range_x=np.pi / 4,\n",
    "            prob=1.0,\n",
    "            keep_size=True,\n",
    "            mode=\"bicubic\",\n",
    "        ),\n",
    "        RandZoomD(\n",
    "            keys=[\"moving_hand\"],\n",
    "            min_zoom=0.9,\n",
    "            max_zoom=1.1,\n",
    "            prob=1.0,\n",
    "            mode=\"bicubic\",\n",
    "            align_corners=False,\n",
    "        ),\n",
    "        EnsureTypeD(keys=[\"fixed_hand\", \"moving_hand\"]),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01369e3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MedNISTDataset(DataInterface):\n",
    "    def __init__(self, **kwargs):\n",
    "        self.kwargs = kwargs\n",
    "\n",
    "    @property\n",
    "    def shard_descriptor(self):\n",
    "        return self._shard_descriptor\n",
    "\n",
    "    @shard_descriptor.setter\n",
    "    def shard_descriptor(self, shard_descriptor):\n",
    "        \"\"\"\n",
    "        Describe per-collaborator procedures or sharding.\n",
    "\n",
    "        This method will be called during a collaborator initialization.\n",
    "        Local shard_descriptor  will be set by Envoy.\n",
    "        \"\"\"\n",
    "        self._shard_descriptor = shard_descriptor\n",
    "\n",
    "        self.train_set = Dataset(\n",
    "            data=self._shard_descriptor.get_dataset(\"train\").data_items,\n",
    "            transform=image_transforms,\n",
    "        )\n",
    "        self.valid_set = Dataset(\n",
    "            data=self._shard_descriptor.get_dataset(\"validation\").data_items,\n",
    "            transform=image_transforms,\n",
    "        )\n",
    "\n",
    "    def get_train_loader(self, **kwargs):\n",
    "        \"\"\"\n",
    "        Output of this method will be provided to tasks with optimizer in contract\n",
    "        \"\"\"\n",
    "        return DataLoader(self.train_set, batch_size=self.kwargs[\"train_bs\"], shuffle=True)\n",
    "\n",
    "    def get_valid_loader(self, **kwargs):\n",
    "        \"\"\"\n",
    "        Output of this method will be provided to tasks without optimizer in contract\n",
    "        \"\"\"\n",
    "        return DataLoader(self.valid_set, batch_size=self.kwargs[\"valid_bs\"])\n",
    "\n",
    "    def get_train_data_size(self):\n",
    "        \"\"\"\n",
    "        Information for aggregation\n",
    "        \"\"\"\n",
    "        return len(self.train_set)\n",
    "\n",
    "    def get_valid_data_size(self):\n",
    "        \"\"\"\n",
    "        Information for aggregation\n",
    "        \"\"\"\n",
    "        return len(self.valid_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa78453c",
   "metadata": {},
   "outputs": [],
   "source": [
    "fed_dataset = MedNISTDataset(train_bs=16, valid_bs=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74cac654",
   "metadata": {},
   "source": [
    "### Describe the model and optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43e25fe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_net = GlobalNet(\n",
    "    image_size=(64, 64),\n",
    "    spatial_dims=2,\n",
    "    in_channels=2,  # moving and fixed\n",
    "    num_channel_initial=16,\n",
    "    depth=3,\n",
    ")\n",
    "\n",
    "image_loss = MSELoss()\n",
    "if USE_COMPILED:\n",
    "    warp_layer = Warp(3, \"border\")\n",
    "else:\n",
    "    warp_layer = Warp(\"bilinear\", \"border\")\n",
    "optimizer_adam = torch.optim.Adam(model_net.parameters(), 1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f097cdc5",
   "metadata": {},
   "source": [
    "### Register model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06a8cca8",
   "metadata": {},
   "outputs": [],
   "source": [
    "framework_adapter = \"openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin\"\n",
    "model_interface = ModelInterface(model=model_net, optimizer=optimizer_adam, framework_plugin=framework_adapter)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "849c165b",
   "metadata": {},
   "source": [
    "## Define and register FL tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9649385",
   "metadata": {},
   "outputs": [],
   "source": [
    "task_interface = TaskInterface()\n",
    "\n",
    "\n",
    "@task_interface.register_fl_task(\n",
    "    model=\"net_model\",\n",
    "    data_loader=\"train_loader\",\n",
    "    device=\"device\",\n",
    "    optimizer=\"optimizer\",\n",
    ")\n",
    "def train(\n",
    "    net_model,\n",
    "    train_loader,\n",
    "    optimizer,\n",
    "    device,\n",
    "    loss_fn=image_loss,\n",
    "    affine_transform=warp_layer,\n",
    "):\n",
    "    train_loader = tqdm.tqdm(train_loader, desc=\"train\")\n",
    "    net_model.train()\n",
    "    net_model.to(device)\n",
    "    warp_layer.to(device)\n",
    "\n",
    "    epoch_loss = 0.0\n",
    "    step = 0\n",
    "\n",
    "    for batch_data in train_loader:\n",
    "        step += 1\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        moving = batch_data[\"moving_hand\"].to(device)\n",
    "        fixed = batch_data[\"fixed_hand\"].to(device)\n",
    "        ddf = net_model(torch.cat((moving, fixed), dim=1))\n",
    "        pred_image = affine_transform(moving, ddf)\n",
    "\n",
    "        loss = loss_fn(pred_image, fixed)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        epoch_loss += loss.item()\n",
    "    epoch_loss /= step\n",
    "    return {\n",
    "        \"train_loss\": epoch_loss,\n",
    "    }\n",
    "\n",
    "\n",
    "@task_interface.register_fl_task(model=\"net_model\", data_loader=\"val_loader\", device=\"device\")\n",
    "def validate(net_model, val_loader, device, loss_fn=image_loss, affine_transform=warp_layer):\n",
    "    net_model.eval()\n",
    "    net_model.to(device)\n",
    "    warp_layer.to(device)\n",
    "\n",
    "    epoch_loss = 0.0\n",
    "    step = 0\n",
    "\n",
    "    val_loader = tqdm.tqdm(val_loader, desc=\"validate\")\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_data in val_loader:\n",
    "            step += 1\n",
    "\n",
    "            moving = batch_data[\"moving_hand\"].to(device)\n",
    "            fixed = batch_data[\"fixed_hand\"].to(device)\n",
    "            ddf = net_model(torch.cat((moving, fixed), dim=1))\n",
    "            pred_image = affine_transform(moving, ddf)\n",
    "            loss = loss_fn(pred_image, fixed)\n",
    "            epoch_loss += loss.item()\n",
    "\n",
    "    epoch_loss /= step\n",
    "    return {\n",
    "        \"validation_loss\": epoch_loss,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f0ebf2d",
   "metadata": {},
   "source": [
    "## Time to start a federated learning experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d41b7896",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create an experiment in the federation\n",
    "experiment_name = \"mednist_experiment\"\n",
    "fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41b44de9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The following command zips the workspace and Python requirements to be transferred to collaborator nodes\n",
    "fl_experiment.start(\n",
    "    model_provider=model_interface,\n",
    "    task_keeper=task_interface,\n",
    "    data_loader=fed_dataset,\n",
    "    rounds_to_train=10,\n",
    "    opt_treatment=\"CONTINUE_GLOBAL\",\n",
    "    device_assignment_policy=\"CUDA_PREFERRED\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83edd88f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check how the experiment is going\n",
    "fl_experiment.stream_metrics(tensorboard_logs=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
